{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import pandas as pd\n",
    "import yaml\n",
    "from IPython.core.display import HTML\n",
    "from IPython.display import display\n",
    "import torch\n",
    "import random\n",
    "\n",
    "from oml.lightning.pipelines.validate import extractor_validation_pipeline\n",
    "from oml.lightning.callbacks.metric import MetricValCallback\n",
    "\n",
    "# Let pandas print up to 330 rows before it truncates a frame.\n",
    "pd.set_option(\"display.max_rows\", 330)\n",
    "\n",
    "# Inject CSS that stretches `.container` elements to the full page width.\n",
    "display(HTML(\"<style>.container { width:100% !important; }</style>\"))\n",
    "\n",
    "# Draw matplotlib figures directly in the cell outputs.\n",
    "%matplotlib inline\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "cfg = f\"\"\"\n",
    "    accelerator: gpu\n",
    "    precision: 32\n",
    "    devices: 1\n",
    "\n",
    "    dataset_root: /path/to/dataset\n",
    "    dataframe_name: df.csv\n",
    "    bs_val: 128\n",
    "    num_workers: 10\n",
    "\n",
    "    transforms_val:\n",
    "      name: norm_resize_hypvit_torch\n",
    "      args:\n",
    "        im_size: 224\n",
    "        crop_size: 224\n",
    "\n",
    "    model:\n",
    "      name: vit\n",
    "      args:\n",
    "        arch: vits16\n",
    "        normalise_features: True\n",
    "        use_multi_scale: False\n",
    "        weights: /path/to/extractor.ckpt\n",
    "\n",
    "    metric_args:\n",
    "      cmc_top_k: [1, 10, 20, 30, 100]\n",
    "      map_top_k: [5, 10]\n",
    "\n",
    "\"\"\""
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Baseline run: validate with the config that has no postprocessor.\n",
    "# The config is plain YAML data, so the restricted safe loader is enough to parse it.\n",
    "trainer, ret_dict = extractor_validation_pipeline(yaml.safe_load(cfg))\n",
    "\n",
    "# Pick the metric callback of this run out of the trainer's callbacks;\n",
    "# the indexing fails with IndexError if the trainer has no MetricValCallback.\n",
    "clb_metric = [x for x in trainer.callbacks if isinstance(x, MetricValCallback)][0]\n"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Config of the second run: the baseline config plus a `postprocessor` section.\n",
    "# The appended text uses the same 4-space base indent as `cfg`, so `postprocessor`\n",
    "# lands in the same top-level YAML mapping as the baseline keys.\n",
    "# Plain string on purpose: nothing is interpolated, and an f-string would treat\n",
    "# any `{`/`}` later added to the YAML as a format field.\n",
    "cfg_p = cfg + \"\"\"\n",
    "    postprocessor:\n",
    "      name: pairwise_reranker\n",
    "      args:\n",
    "        top_n: 5\n",
    "        pairwise_model:\n",
    "          name: concat_siamese\n",
    "          args:\n",
    "            mlp_hidden_dims: [192]\n",
    "            weights: /path/to/postprocessor.ckpt\n",
    "            extractor:\n",
    "              name: vit\n",
    "              args:\n",
    "                arch: vits16\n",
    "                normalise_features: False\n",
    "                use_multi_scale: False\n",
    "                weights: null\n",
    "        num_workers: 10\n",
    "        batch_size: 128\n",
    "        verbose: True\n",
    "\n",
    "\"\"\""
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Second run: same pipeline, but with the config that includes the postprocessor.\n",
    "# The config is plain YAML data, so the restricted safe loader is enough to parse it.\n",
    "trainer_p, ret_dict_p = extractor_validation_pipeline(yaml.safe_load(cfg_p))\n",
    "\n",
    "# Pick the metric callback of this run out of the trainer's callbacks;\n",
    "# the indexing fails with IndexError if the trainer has no MetricValCallback.\n",
    "clb_metric_p = [x for x in trainer_p.callbacks if isinstance(x, MetricValCallback)][0]\n"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Read the unreduced \"cmc\" entry stored under key 1 of the \"OVERALL\" group,\n",
    "# first for the baseline run and then for the postprocessed one.\n",
    "# Presumably these are per-query CMC@1 values - TODO confirm against the metric class.\n",
    "cmc_1, cmc_1_p = (\n",
    "    clb.metric.metrics_unreduced[\"OVERALL\"][\"cmc\"][1]\n",
    "    for clb in (clb_metric, clb_metric_p)\n",
    ")\n"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Let's visualize the cases where postprocessor has improved the desired metric:\n",
    "\n",
    "N_EXAMPLES = 10  # upper bound on the number of queries to plot\n",
    "SEED = 42  # fixed so that re-running the notebook shows the same queries; change it to see others\n",
    "\n",
    "# flatten() (unlike squeeze()) keeps the result 1-D even when exactly one query improved,\n",
    "# so tolist() always returns a list rather than a bare int.\n",
    "# Assumes cmc_1 / cmc_1_p are 1-D tensors with one value per query - TODO confirm.\n",
    "improved_ids = torch.nonzero(cmc_1_p > cmc_1).flatten().tolist()\n",
    "\n",
    "# random.sample raises ValueError when asked for more items than available, so cap the sample size.\n",
    "# A private Random instance keeps the global random state untouched.\n",
    "ids = random.Random(SEED).sample(improved_ids, min(N_EXAMPLES, len(improved_ids)))\n",
    "\n",
    "for idx in ids:\n",
    "    # The same query is plotted twice: first for the baseline run, then for the postprocessed one.\n",
    "    fig = clb_metric.metric.get_plot_for_queries([idx], n_instances=4, verbose=False)\n",
    "    fig = clb_metric_p.metric.get_plot_for_queries([idx], n_instances=4, verbose=False)\n"
   ],
   "metadata": {
    "collapsed": false
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
